"""
Test combined usage of custom LangChain LLM and retriever.

This tests the real-world scenario where users provide both custom LLMs
and custom retrievers to create a fully customized research pipeline.
"""

import os
import pytest
from unittest.mock import patch
from typing import List, Any, Optional
from langchain_core.retrievers import BaseRetriever, Document
from langchain_core.language_models.llms import LLM
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from pydantic import Field

from local_deep_research.api.research_functions import (
    quick_summary,
    detailed_research,
)


class CompanyKnowledgeRetriever(BaseRetriever):
    """Simulates a company's internal knowledge base retriever."""

    knowledge_domain: str = Field(default="general")
    documents: List[Document] = Field(default_factory=list)

    def __init__(self, knowledge_domain: str = "general", **kwargs):
        """Initialize with a specific knowledge domain."""
        super().__init__(knowledge_domain=knowledge_domain, **kwargs)
        self.documents = self._load_domain_documents()

    def _load_domain_documents(self) -> List[Document]:
        """Load documents based on domain."""
        if self.knowledge_domain == "engineering":
            return [
                Document(
                    page_content="Our engineering best practices emphasize code review, testing, and CI/CD.",
                    metadata={
                        "source": "eng_handbook",
                        "section": "best_practices",
                        "last_updated": "2024-01-10",
                    },
                ),
                Document(
                    page_content="We use microservices architecture with Kubernetes for scalability.",
                    metadata={
                        "source": "eng_handbook",
                        "section": "architecture",
                        "last_updated": "2023-12-15",
                    },
                ),
                Document(
                    page_content="Security practices include regular audits and automated vulnerability scanning.",
                    metadata={
                        "source": "eng_handbook",
                        "section": "security",
                        "last_updated": "2024-01-20",
                    },
                ),
            ]
        elif self.knowledge_domain == "product":
            return [
                Document(
                    page_content="Product roadmap focuses on AI-driven features and user experience.",
                    metadata={
                        "source": "product_docs",
                        "section": "roadmap",
                        "quarter": "Q1-2024",
                    },
                ),
                Document(
                    page_content="User research shows demand for better integration capabilities.",
                    metadata={
                        "source": "product_docs",
                        "section": "research",
                        "date": "2024-01-05",
                    },
                ),
            ]
        else:
            return [
                Document(
                    page_content="Company mission is to democratize AI technology.",
                    metadata={"source": "company_docs", "type": "mission"},
                ),
            ]

    def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
        """Get relevant documents based on query."""
        relevant = []
        query_lower = query.lower()

        for doc in self.documents:
            if any(
                word in doc.page_content.lower() for word in query_lower.split()
            ):
                relevant.append(doc)

        return relevant if relevant else self.documents[:1]

    async def _aget_relevant_documents(
        self, query: str, **kwargs
    ) -> List[Document]:
        return self._get_relevant_documents(query, **kwargs)


class CompanyCustomLLM(LLM):
    """Simulates a company's custom fine-tuned LLM."""

    model_version: str = Field(default="v1")
    use_company_style: bool = Field(default=True)

    def __init__(
        self,
        model_version: str = "v1",
        use_company_style: bool = True,
        **kwargs,
    ):
        """Initialize with company-specific parameters."""
        super().__init__(
            model_version=model_version,
            use_company_style=use_company_style,
            **kwargs,
        )

    @property
    def _llm_type(self) -> str:
        return f"company_llm_{self.model_version}"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate response with company-specific style."""
        # Simulate different responses based on prompt content
        response = self._generate_response(prompt)

        if self.use_company_style:
            # Add company-specific formatting
            response = f"Based on our analysis: {response}"

        return response

    def _generate_response(self, prompt: str) -> str:
        """Generate response based on prompt content."""
        prompt_lower = prompt.lower()

        if "engineering" in prompt_lower or "technical" in prompt_lower:
            return "Our engineering practices emphasize scalability, security, and maintainability. We follow industry best practices with a focus on microservices architecture and continuous integration."
        elif "product" in prompt_lower or "roadmap" in prompt_lower:
            return "The product strategy centers on AI-driven features that enhance user productivity. Key focus areas include seamless integrations and intuitive user interfaces."
        elif "security" in prompt_lower:
            return "Security is paramount in our systems. We implement defense-in-depth strategies, regular security audits, and automated vulnerability scanning."
        elif "question" in prompt_lower and "generate" in prompt_lower:
            # For question generation
            return "1. What are the key architectural decisions?\n2. How do we ensure security compliance?\n3. What are the integration requirements?"
        else:
            return f"Our comprehensive approach addresses the query about {prompt[:50]}... with focus on quality and innovation."

    @property
    def _identifying_params(self) -> dict:
        return {
            "model_version": self.model_version,
            "use_company_style": self.use_company_style,
        }


@pytest.mark.skipif(
    os.environ.get("CI") == "true"
    or os.environ.get("GITHUB_ACTIONS") == "true",
    reason="Langchain integration tests skipped in CI - testing advanced features",
)
class TestCombinedLLMRetriever:
    """Test suite for combined custom LLM and retriever usage."""

    @pytest.fixture
    def settings_snapshot(self):
        """Create settings snapshot for testing."""
        return {
            "llm.provider": {"value": "none", "type": "str"},
            "llm.model": {"value": "company_llm_v1", "type": "str"},
            "llm.temperature": {"value": 0.7, "type": "float"},
            "llm.custom.api_key": {
                "value": "company-internal-key",
                "type": "str",
            },
            "llm.custom.endpoint": {
                "value": "https://internal-llm.company.com",
                "type": "str",
            },
            "research.iterations": {"value": 2, "type": "int"},
            "research.questions_per_iteration": {"value": 3, "type": "int"},
            "research.local_context": {"value": 2000, "type": "int"},
            "research.web_context": {"value": 2000, "type": "int"},
            "llm.context_window_unrestricted": {"value": False, "type": "bool"},
            "llm.context_window_size": {"value": 8192, "type": "int"},
            "llm.local_context_window_size": {"value": 4096, "type": "int"},
            "llm.supports_max_tokens": {"value": True, "type": "bool"},
            "llm.max_tokens": {"value": 4096, "type": "int"},
            "rate_limiting.llm_enabled": {"value": False, "type": "bool"},
            "search.tool": {"value": "auto", "type": "str"},
            "search.max_results": {"value": 10, "type": "int"},
            "search.cross_engine_max_results": {"value": 100, "type": "int"},
            "search.cross_engine_use_reddit": {"value": False, "type": "bool"},
            "search.cross_engine_min_date": {"value": None, "type": "str"},
            "search.region": {"value": "us", "type": "str"},
            "search.time_period": {"value": "y", "type": "str"},
            "search.safe_search": {"value": True, "type": "bool"},
            "search.snippets_only": {"value": True, "type": "bool"},
            "search.search_language": {"value": "English", "type": "str"},
            "search.max_filtered_results": {"value": 20, "type": "int"},
        }

    def test_engineering_research_with_custom_components(
        self, settings_snapshot
    ):
        """Test engineering-focused research with custom LLM and retriever."""
        # Create custom components
        eng_retriever = CompanyKnowledgeRetriever(
            knowledge_domain="engineering"
        )
        custom_llm = CompanyCustomLLM(
            model_version="v1", use_company_style=True
        )

        result = quick_summary(
            query="What are our engineering best practices for microservices?",
            research_id=10001,
            llms={"custom": custom_llm},
            retrievers={"eng_kb": eng_retriever},
            search_tool="eng_kb",
            settings_snapshot=settings_snapshot,
            iterations=1,
            questions_per_iteration=2,
        )

        assert result["research_id"] == 10001
        assert "engineering" in result["summary"].lower()
        assert "Based on our analysis" in result["summary"]  # Company style
        assert len(result["sources"]) > 0

        # Verify sources are from engineering knowledge base
        for source in result["sources"]:
            if hasattr(source, "metadata"):
                assert source.metadata.get("source") == "eng_handbook"

    def test_multi_domain_research(self, settings_snapshot):
        """Test research across multiple company domains."""
        # Create retrievers for different domains
        eng_retriever = CompanyKnowledgeRetriever(
            knowledge_domain="engineering"
        )
        product_retriever = CompanyKnowledgeRetriever(
            knowledge_domain="product"
        )
        general_retriever = CompanyKnowledgeRetriever(
            knowledge_domain="general"
        )

        # Create custom LLM
        custom_llm = CompanyCustomLLM(model_version="v2")

        with patch(
            "random.randint",
            return_value=20002,
        ):
            result = detailed_research(
                query="How do our engineering practices align with product roadmap?",
                llms={"custom": custom_llm},
                retrievers={
                    "engineering": eng_retriever,
                    "product": product_retriever,
                    "general": general_retriever,
                },
                search_tool="auto",  # Use all retrievers
                settings_snapshot=settings_snapshot,
                iterations=2,
                questions_per_iteration=3,
            )

        assert result["research_id"] == 20002
        assert "engineering" in result["summary"].lower()
        assert "product" in result["summary"].lower()

        # Check that we have sources from multiple domains
        source_types = set()
        for source in result["sources"]:
            if hasattr(source, "metadata"):
                source_types.add(source.metadata.get("source"))

        # Should have sources from at least 2 different domains
        assert len(source_types) >= 2

    def test_custom_llm_factory_pattern(self, settings_snapshot):
        """Test using a factory pattern for creating custom components."""

        def create_company_llm(settings_snapshot):
            """Factory function that uses settings to configure LLM."""
            # Extract company-specific settings
            api_key = settings_snapshot.get("llm.custom.api_key", {}).get(
                "value"
            )
            endpoint = settings_snapshot.get("llm.custom.endpoint", {}).get(
                "value"
            )
            model_version = settings_snapshot.get("llm.model", {}).get(
                "value", "v1"
            )

            # Validate settings
            assert api_key == "company-internal-key"
            assert endpoint == "https://internal-llm.company.com"

            # Create LLM with settings
            version = (
                model_version.split("_")[-1] if "_" in model_version else "v1"
            )
            return CompanyCustomLLM(model_version=version)

        def create_company_retriever(domain: str, settings_snapshot):
            """Factory function for creating domain-specific retrievers."""
            # Could use settings to configure retriever
            # For example, different endpoints for different domains
            return CompanyKnowledgeRetriever(knowledge_domain=domain)

        # Use factories to create components
        llm = create_company_llm(settings_snapshot)
        retriever = create_company_retriever("engineering", settings_snapshot)

        with patch(
            "random.randint",
            return_value=30003,
        ):
            result = quick_summary(
                query="Security best practices",
                llms={"corporate": llm},
                retrievers={"company_kb": retriever},
                search_tool="company_kb",
                settings_snapshot=settings_snapshot,
                iterations=1,
            )

        assert result["research_id"] == 30003
        assert "security" in result["summary"].lower()

    def test_fallback_handling(self, settings_snapshot):
        """Test fallback when custom components partially fail."""

        class UnreliableRetriever(BaseRetriever):
            """Retriever that sometimes fails."""

            call_count: int = Field(default=0)

            def __init__(self, **kwargs):
                super().__init__(**kwargs)
                self.call_count = 0

            def _get_relevant_documents(
                self, query: str, **kwargs
            ) -> List[Document]:
                self.call_count += 1
                if self.call_count == 1:
                    # Fail on first call
                    raise RuntimeError("Temporary retriever failure")
                else:
                    # Success on retry
                    return [
                        Document(
                            page_content="Fallback content about the query",
                            metadata={
                                "source": "fallback",
                                "retry": self.call_count,
                            },
                        )
                    ]

            async def _aget_relevant_documents(
                self, query: str, **kwargs
            ) -> List[Document]:
                return self._get_relevant_documents(query, **kwargs)

        unreliable_retriever = UnreliableRetriever()
        custom_llm = CompanyCustomLLM()

        # First attempt should fail
        with pytest.raises(RuntimeError, match="Temporary retriever failure"):
            quick_summary(
                query="Test query",
                llms={"custom": custom_llm},
                retrievers={"unreliable": unreliable_retriever},
                search_tool="unreliable",
                settings_snapshot=settings_snapshot,
                iterations=1,
            )

    def test_performance_monitoring(self, settings_snapshot):
        """Test that custom components can be monitored for performance."""

        class MonitoredLLM(CompanyCustomLLM):
            """LLM that tracks performance metrics."""

            call_count: int = Field(default=0)
            total_tokens: int = Field(default=0)
            response_times: List[float] = Field(default_factory=list)

            def __init__(self, **kwargs):
                super().__init__(**kwargs)
                self.call_count = 0
                self.total_tokens = 0
                self.response_times = []

            def _call(self, prompt: str, **kwargs) -> str:
                import time

                start_time = time.time()

                self.call_count += 1
                response = super()._call(prompt, **kwargs)

                # Simulate token counting
                self.total_tokens += len(prompt.split()) + len(response.split())

                # Track response time
                self.response_times.append(time.time() - start_time)

                return response

        monitored_llm = MonitoredLLM()
        retriever = CompanyKnowledgeRetriever("engineering")

        with patch(
            "random.randint",
            return_value=40004,
        ):
            result = quick_summary(
                query="Performance testing query",
                llm=monitored_llm,
                retrievers={"eng": retriever},
                search_tool="eng",
                settings_snapshot=settings_snapshot,
                iterations=2,
                questions_per_iteration=2,
            )

        assert result["research_id"] == 40004

        # Verify monitoring worked
        assert monitored_llm.call_count > 0
        assert monitored_llm.total_tokens > 0
        assert len(monitored_llm.response_times) == monitored_llm.call_count

    def test_real_world_scenario(self, settings_snapshot):
        """Test a realistic scenario with multiple components and error handling."""
        # Create components
        eng_retriever = CompanyKnowledgeRetriever("engineering")
        product_retriever = CompanyKnowledgeRetriever("product")
        custom_llm = CompanyCustomLLM(
            model_version="v2", use_company_style=True
        )

        # Mock web search for hybrid approach
        web_results = [
            {
                "url": "https://external.com/best-practices",
                "title": "Industry Best Practices",
                "content": "Industry standards for microservices include...",
                "source": "web",
            }
        ]

        with patch(
            "local_deep_research.config.search_config.get_search",
            return_value=web_results,
        ):
            with patch(
                "random.randint",
                return_value=50005,
            ):
                result = detailed_research(
                    query="Compare our engineering practices with industry standards",
                    llms={"custom": custom_llm},
                    retrievers={
                        "internal_eng": eng_retriever,
                        "internal_product": product_retriever,
                    },
                    search_tool="meta",
                    meta_search_config={
                        "retrievers": ["internal_eng", "internal_product"],
                        "engines": ["wikipedia"],  # Also search web
                    },
                    settings_snapshot=settings_snapshot,
                    iterations=3,
                    questions_per_iteration=4,
                    enable_followup_questions=True,
                )

        assert result["research_id"] == 50005
        assert "engineering" in result["summary"].lower()
        assert (
            "Based on our analysis" in result["summary"]
        )  # Company style preserved

        # Should have both internal and external sources
        internal_sources = 0
        external_sources = 0

        for source in result["sources"]:
            if hasattr(source, "metadata"):
                if source.metadata.get("source") in [
                    "eng_handbook",
                    "product_docs",
                ]:
                    internal_sources += 1
            elif isinstance(source, dict) and source.get("source") == "web":
                external_sources += 1

        assert internal_sources > 0  # Has internal knowledge
        assert external_sources > 0  # Has external perspective
